# -*- coding: utf-8 -*-
# @Time    : 2021/9/6 8:19
# @Author  : 万方名

import os
import torch
import argparse
import pandas as pd
import torch.optim as optim

import torch.utils.data as Data
import torch.nn.functional as F

from torch import nn
from torch.autograd import Variable
from gensim.models import KeyedVectors
from torch.utils.data import DataLoader

# Pin this process to physical GPU index 7 (it appears as cuda:0 inside torch).
os.environ["CUDA_VISIBLE_DEVICES"] = "7"


class MLP(nn.Module):
    """Two-layer perceptron classifying a concatenated word-pair vector.

    Input: a float tensor of shape (batch, 200) — two 100-dim word vectors
    concatenated. Output: raw logits of shape (batch, 2).
    """

    def __init__(self, device):
        # `device` is unused; kept for backward compatibility with callers.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(200, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )
        # NOTE(review): this criterion is never used (train/valid build their
        # own loss); kept so existing references to `.ce` do not break.
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x):
        # Fix: no hard-coded `.cuda()` here — callers are responsible for
        # placing inputs on the right device, so the model also works on CPU.
        return self.layers(x)

    def configure_optimizers(self):
        """Return a default Adam optimizer (Lightning-style hook)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer


class TokenClf(nn.Module):
    """Wrap a pretrained word2vec vocabulary as a frozen embedding layer.

    Exposes:
      - ``embedding``: frozen ``nn.Embedding`` initialized from the vectors.
      - ``idx2key`` / ``key2idx``: vocabulary lookup tables.
    """

    def __init__(self, device, dict_path='/data/word_embedding_cn/cn_0623.vec'):
        # `device` is unused; kept for backward compatibility.
        # `dict_path` is new and defaulted, so existing callers are unaffected.
        # NOTE(review): main() loads the same vec file separately — consider
        # passing the loaded KeyedVectors in to avoid reading it twice.
        super().__init__()
        Word2VecModel = KeyedVectors.load_word2vec_format(dict_path, binary=False)
        # freeze=True: embeddings are lookup-only, never trained.
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(Word2VecModel.vectors), freeze=True)
        self.mlp = MLP(device)
        self.idx2key = Word2VecModel.index_to_key
        self.key2idx = Word2VecModel.key_to_index


def train(model, device, tokenclf, train_loader, optimizer, epoch):
    """Run one training epoch.

    Args:
        model: classifier taking a (batch, 200) float tensor, returning logits.
        device: torch.device to run on (previously ignored — now honored).
        tokenclf: object with a frozen ``embedding`` layer for index lookup.
        train_loader: yields (index_pairs, target) with index_pairs (batch, 2).
        optimizer: the optimizer to step. Fix: the original built a brand-new
            Adam optimizer on every batch (discarding the caller's optimizer
            and resetting Adam's moment estimates each step); we now use the
            one passed in.
        epoch: epoch number, for logging only.
    """
    model = model.to(device)
    model.train()
    for batch_idx, (indexs, target) in enumerate(train_loader):
        # Look up both word vectors and concatenate -> (batch, 200).
        # The batch is already a tensor; no deprecated Variable wrapping needed.
        tensor_1 = tokenclf.embedding(indexs[:, 0].long())
        tensor_2 = tokenclf.embedding(indexs[:, 1].long())
        tensor_cat = torch.cat([tensor_1, tensor_2], dim=1).to(device)
        target = target.to(device)

        output = model(tensor_cat)
        loss = F.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(indexs), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


def valid(model, device, tokenclf, test_loader):
    """Evaluate the model on the test loader; print and return metrics.

    Fix: the original summed ``F.nll_loss`` over raw logits — the model has no
    log-softmax, so that loss was meaningless and inconsistent with the
    CrossEntropy loss used in training. We use ``F.cross_entropy`` instead.
    Also honors ``device`` instead of hard-coding ``.cuda()``.

    Returns:
        (avg_loss, accuracy): new, backward-compatible return value
        (the original returned None, which no caller used).
    """
    model = model.to(device)
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for indexs, target in test_loader:
            tensor_1 = tokenclf.embedding(indexs[:, 0].long())
            tensor_2 = tokenclf.embedding(indexs[:, 1].long())
            tensor_cat = torch.cat([tensor_1, tensor_2], dim=1).to(device)
            target = target.to(device)
            output = model(tensor_cat)
            # sum up batch loss; same criterion as training
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * accuracy))
    return test_loss, accuracy


def del_unk(word_1, word_2, Word2VecModel):
    """Return 1 when both words exist in the vocabulary, else 0."""
    both_known = word_1 in Word2VecModel and word_2 in Word2VecModel
    return 1 if both_known else 0


def tran2indexs(word_1, word_2, key2idx):
    """Map a word pair to its pair of vocabulary indices."""
    return [key2idx[word] for word in (word_1, word_2)]


def f2i(f):
    """Encode a relation label: 'AND' -> 0, anything else -> 1."""
    return 0 if f == 'AND' else 1


def main():
    """Command-line entry point: load data, train the MLP, save the weights.

    Pipeline:
      1. Parse hyper-parameters from the CLI.
      2. Load the labelled word-pair CSV; shuffle; split 90/10 train/test.
      3. Drop pairs with out-of-vocabulary words; map words to embedding indices.
      4. Train for --epochs epochs, validating after each, then save state_dict.
    """
    # settings (help strings fixed to match the actual defaults)
    parser = argparse.ArgumentParser(description='PyTorch MLP')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda")

    # read data
    data_path = './data/train_data.csv'
    data_df = pd.read_csv(data_path)
    data_df = data_df.sample(frac=1).reset_index(drop=True)  # shuffle all rows

    data_df['LABEL_INDEX'] = data_df['LABEL'].apply(f2i)
    # cut train&test data. Fix: .copy() so the column assignments below write
    # to independent frames instead of views of data_df (avoids pandas
    # SettingWithCopyWarning and potentially lost writes).
    train_data = data_df.sample(frac=0.9, random_state=0, axis=0).copy()
    test_data = data_df[~data_df.index.isin(train_data.index)].copy()
    print(f'共 {len(data_df)} 条数据，train {len(train_data)} 条，test {len(test_data)} 条。')

    # read dict (word2vec vocabulary + vectors)
    dict_path = '/data/word_embedding_cn/cn_0623.vec'
    Word2VecModel = KeyedVectors.load_word2vec_format(dict_path, binary=False)

    key2idx = Word2VecModel.key_to_index

    # data preprocessing
    # 1. drop rows where either word is out of vocabulary
    train_data['if_del'] = train_data.apply(lambda x: del_unk(x['WORD_1'], x['WORD_2'], Word2VecModel), axis=1)
    train_data = train_data.drop(train_data[train_data['if_del'] == 0].index)

    test_data['if_del'] = test_data.apply(lambda x: del_unk(x['WORD_1'], x['WORD_2'], Word2VecModel), axis=1)
    test_data = test_data.drop(test_data[test_data['if_del'] == 0].index)

    # 2. map each word pair to its embedding-row indices
    train_data['train_indexs'] = train_data.apply(lambda x: tran2indexs(x['WORD_1'], x['WORD_2'], key2idx), axis=1)
    test_data['test_indexs'] = test_data.apply(lambda x: tran2indexs(x['WORD_1'], x['WORD_2'], key2idx), axis=1)

    # 3. build tensors and loaders (drop_last: partial batches are discarded)
    train_X_feature = list(train_data['train_indexs'].values)
    train_X_tensor = torch.tensor(train_X_feature)
    train_y_feature = list(train_data['LABEL_INDEX'].values)
    train_y_tensor = torch.tensor(train_y_feature)

    test_X_feature = list(test_data['test_indexs'].values)
    test_X_tensor = torch.tensor(test_X_feature)
    test_y_feature = list(test_data['LABEL_INDEX'].values)
    test_y_tensor = torch.tensor(test_y_feature)

    torch_train_dataset = Data.TensorDataset(train_X_tensor, train_y_tensor)
    train_dat = DataLoader(torch_train_dataset, batch_size=args.batch_size, drop_last=True)
    torch_test_dataset = Data.TensorDataset(test_X_tensor, test_y_tensor)
    test_dat = DataLoader(torch_test_dataset, batch_size=args.batch_size, drop_last=True)

    mlp = MLP(device)
    # NOTE(review): TokenClf reloads the same vec file loaded above — consider
    # sharing one KeyedVectors instance to halve startup time.
    tokenclf = TokenClf(device)

    # NOTE(review): --lr is parsed but not applied here (Adadelta uses its own
    # default); left unchanged to preserve existing training behavior.
    optimizer = optim.Adadelta(mlp.parameters())
    for epoch in range(1, args.epochs + 1):
        train(mlp, device, tokenclf, train_dat, optimizer, epoch)
        valid(mlp, device, tokenclf, test_dat)

    torch.save(mlp.state_dict(), f"/data/wanfangming_data/query_parser/models/torch_mlp_epoch{args.epochs}.pt")


# Script entry point.
if __name__ == '__main__':
    main()
